import torch.nn as nn
import torch.nn.functional as F
import torch
from Utils import TokenList, TokenIndex, NodeIndex, fileIO
from random import shuffle

class GraphOps(object):
    """Helpers for turning knowledge-graph edge lists into dense tensors."""

    @staticmethod
    def process_graph(graph):
        """Build a dense directed adjacency matrix from an edge list.

        Each edge is an indexable sequence of which only the first and last
        elements (source and target node) are used, so both ``(src, dst)``
        pairs and ``(src, relation, dst)`` triples are accepted. Nodes must be
        hashable.

        Nodes are numbered in order of first appearance. Unlike iterating a
        ``set`` (whose order for ``str`` keys changes with hash randomization),
        this gives the same numbering for the same graph on every run.

        Returns:
            (graph_matrix, n_graph_nodes): an ``n x n`` float tensor holding
            1.0 at ``[src, dst]`` for every edge (0.0 elsewhere), and the
            number ``n`` of distinct nodes.
        """
        endpoints = [(edge[0], edge[-1]) for edge in graph]

        node_ids = {}
        for src, dst in endpoints:
            for node in (src, dst):
                node_ids.setdefault(node, len(node_ids))
        n_graph_nodes = len(node_ids)

        graph_matrix = torch.zeros(n_graph_nodes, n_graph_nodes)
        for src, dst in endpoints:
            graph_matrix[node_ids[src], node_ids[dst]] = 1.0

        return graph_matrix, n_graph_nodes

class KInspector(nn.Module):
    """Predicts a (padded) knowledge-graph adjacency matrix.

    Reads the module-level settings ``max_nodes``, ``kge_size`` and
    ``kgn_layers`` at construction time. ``train``/``evaluate`` additionally
    read the module-level ``CLEVRER_data`` and ``eval_data``.
    NOTE(review): when a ``rep`` is passed to ``forward`` its width must match
    ``kge_size`` (in Ki2GM it is TextGen's ``e_size``; both are 96 in __main__).
    """

    def __init__(self):

        super().__init__()
        self.node_embeddings = nn.Embedding(max_nodes,kge_size)
        self.ffns = nn.ModuleList([nn.Linear(max_nodes,max_nodes) for _ in range(kgn_layers)])

    def forward(self,G,nG,rep=None):
        """Score every node pair and return a ``(max_nodes, max_nodes)`` map.

        Args:
            G: the ground-truth adjacency matrix (unused here; kept for the
               call signature used by the training loops).
            nG: number of real nodes; must not exceed ``max_nodes``.
            rep: optional conditioning vector added to every node embedding.
                 When ``None`` the embeddings are used unchanged.

        Returns:
            Sigmoid scores in [0, 1]; rows/columns past ``nG`` start from
            zero-padded pair scores before the FFN stack.
        """
        n_es = self.node_embeddings(torch.arange(nG))
        if rep is not None:
            # Broadcast the same conditioning vector onto every node embedding.
            n_es = n_es + rep

        # All pairwise dot products at once, zero-padded up to max_nodes.
        pair_scores = n_es @ n_es.T
        pad = max_nodes - nG
        node_matrix = F.pad(pair_scores, (0, pad, 0, pad))

        for ffn in self.ffns:
            node_matrix = F.leaky_relu(ffn(node_matrix))

        return torch.sigmoid(node_matrix)

    @staticmethod
    def _padded_target(G, nG):
        """Copy the ``nG x nG`` adjacency ``G`` into a zero ``max_nodes`` square."""
        gt = torch.zeros(max_nodes,max_nodes)
        gt[:nG, :nG] = G[:nG, :nG]
        return gt

    def evaluate(self):
        """Return ``(summed per-graph MSE, cell accuracy)`` over ``eval_data``.

        Accuracy compares thresholded predictions with the ground truth on the
        ``nG x nG`` real-node region only. It is 0.0 when nothing was compared.
        """
        loss = 0.0
        threshold = 0.5
        compare_counter = 0
        true_counter = 0

        # Evaluation must not build autograd graphs.
        with torch.no_grad():
            for datapoint in eval_data:
                KG = datapoint[0][1]['KG']
                G, nG = GraphOps.process_graph(KG)
                kg_a = self(G,nG)
                gt = self._padded_target(G, nG)
                loss += torch.mean(torch.pow(gt-kg_a,2))

                preds = (kg_a[:nG, :nG] > threshold).float()
                true_counter += int((preds == gt[:nG, :nG]).sum())
                compare_counter += nG * nG

        accuracy = true_counter/float(compare_counter) if compare_counter else 0.0
        return loss, accuracy

    def train(self,epochs = 10):
        """Fit the inspector on ``CLEVRER_data`` for ``epochs`` epochs.

        ``nn.Module.eval()`` calls ``self.train(False)``, and ``nn.Module.train``
        recurses into child modules with a bool. A bool argument is therefore
        treated as a mode switch and delegated to ``nn.Module.train``, which
        keeps ``.eval()`` / ``.train(True)`` working.
        """
        if isinstance(epochs, bool):
            return super().train(epochs)

        opt = torch.optim.AdamW(self.parameters())

        for epoch in range(epochs):
            shuffle(CLEVRER_data)
            for datapoint in CLEVRER_data:
                KG = datapoint[0][1]['KG']
                G, nG = GraphOps.process_graph(KG)
                kg_a = self(G,nG)
                gt = self._padded_target(G, nG)
                loss = torch.sum(torch.pow(gt-kg_a,2))
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss, eval_acc = self.evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

class TextTokenizer(object):
    """Space-delimited word tokenizer built from questions and programs.

    Each datapoint is read as ``datapoint[0][0]`` (question string) and
    ``datapoint[1]`` (list of program tokens).

    Side effect: for every datapoint, adds its token count plus one END slot
    to the module-level ``c_size`` (TextGen sizes its positional-embedding
    table with it). NOTE(review): this is a running total over the whole
    dataset, much larger than the longest sequence, and it is never reset,
    so constructing a second tokenizer grows it again.
    """

    def __init__(self,data):

        global c_size

        self.tokens = TokenList()
        self.token_index = TokenIndex()
        self.n_tokens = 0

        for datapoint in data:
            question = datapoint[0][0]
            prg = datapoint[1]
            datapoint_tokens = question.split(' ')+prg
            n_END_tokens = 1
            c_size += len(datapoint_tokens) + n_END_tokens
            self.tokens += datapoint_tokens

        # De-duplicate while keeping first-appearance order. With set(), the
        # order of str tokens (and so every token id) changes between runs
        # because of hash randomization.
        self.tokens = list(dict.fromkeys(self.tokens))
        self.n_tokens = len(self.tokens)

        for n, token in enumerate(self.tokens):
            self.token_index.add(token,n)

    def encode(self,inp):
        """Split ``inp`` on single spaces and map each piece to its token id."""
        return [self.token_index.get(token) for token in inp.split(' ')]

    def decode(self,token_indices):
        """Map token ids back to their token strings."""
        return [self.tokens[token_idx] for token_idx in token_indices]

class DataOps(object):
    """Turns raw datapoints into next-token-prediction training pairs."""

    @staticmethod
    def process_datapoint(datapoint,tk=None):
        """Encode question + program and expand it into (prefix, next-id) pairs.

        The question (``datapoint[0][0]``) and the space-joined program tokens
        (``datapoint[1]``) are encoded with ``tk`` and concatenated into one id
        sequence ``s``. For every position ``k >= 1`` the prefix ``s[:k]`` is
        paired with the id ``s[k]``.

        Returns:
            (X, Y): list of prefixes and list of their following ids. Both are
            empty when the sequence has fewer than two ids.
        """
        question = datapoint[0][0]
        program_tokens = datapoint[1]
        ids = tk.encode(question) + tk.encode(' '.join(program_tokens))

        prefixes = [ids[:stop] for stop in range(1, len(ids))]
        next_ids = ids[1:]
        return prefixes, next_ids

class TextGen(nn.Module):
    """Feed-forward next-token predictor over tokenizer ids.

    Reads the module-level settings ``n_tokens``, ``e_size``, ``c_size``,
    ``h_size`` and ``n_layers`` at construction time. ``train``/``evaluate``
    additionally read ``CLEVRER_data``, ``eval_data`` and the tokenizer ``tk``.
    """

    def __init__(self):

        super().__init__()
        self.embeddings = nn.Embedding(n_tokens,e_size)
        self.pos_embeddings = nn.Embedding(c_size,e_size)
        self.ffn_i = nn.Linear(e_size,h_size)
        self.ffn_H = nn.ModuleList([nn.Linear(h_size,h_size) for _ in range(n_layers)])
        self.ffn_o = nn.Linear(h_size,e_size)
        self.head = nn.Linear(e_size,n_tokens)

    def forward(self,X):
        """Predict the next token for each prefix in ``X``.

        Args:
            X: non-empty list of token-id lists. Each prefix must be shorter
               than ``c_size`` (positional-embedding table size).

        Returns:
            (rep, logits): ``rep`` is the mean over positions of the final
            ``e_size`` hidden states of the LAST prefix in ``X``. ``logits``
            stacks, for every prefix, the head output at its final position,
            giving shape ``(len(X), n_tokens)``.
        """
        logits = []
        for x in X:
            nx_tokens = len(x)
            e_x = self.embeddings(torch.tensor(x))
            p_e = self.pos_embeddings(torch.arange(nx_tokens))
            e_x += p_e
            e_x = F.leaky_relu(self.ffn_i(e_x))
            for ffn_h in self.ffn_H:
                e_x = F.leaky_relu(ffn_h(e_x))
            e_x = F.leaky_relu(self.ffn_o(e_x))
            rep = torch.mean(e_x,dim=0)
            e_x = self.head(e_x)
            logit_set = e_x[-1]
            logits.append(logit_set)

        return rep, torch.row_stack(logits)

    def train(self,epochs = 10):
        """Fit on ``CLEVRER_data`` for ``epochs`` epochs, evaluating each step.

        ``nn.Module.eval()`` calls ``self.train(False)``, and ``nn.Module.train``
        recurses into child modules with a bool. A bool argument is therefore
        treated as a mode switch and delegated to ``nn.Module.train``.
        """
        if isinstance(epochs, bool):
            return super().train(epochs)

        opt = torch.optim.AdamW(self.parameters())
        CE = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            shuffle(CLEVRER_data)
            for datapoint in CLEVRER_data:
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not X:
                    continue  # fewer than two tokens: nothing to predict
                rep, logits = self(X)
                # Class-index targets give the same loss as one-hot probabilities.
                loss = CE(logits,torch.tensor(Y))
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss, eval_acc = self.evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

    def evaluate(self):
        """Return ``(summed cross-entropy, next-token accuracy)``.

        Only the first two items of ``eval_data`` are evaluated (kept from the
        original, presumably for speed). The loss is summed over those
        datapoints, matching ``KInspector.evaluate``. Accuracy is 0.0 when
        there was nothing to predict.
        """
        CE = nn.CrossEntropyLoss()
        loss = 0.0
        compare_counter = 0
        true_counter = 0

        # Evaluation must not build autograd graphs.
        with torch.no_grad():
            for datapoint in eval_data[:2]:
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not X:
                    continue
                rep, logits = self(X)
                targets = torch.tensor(Y)
                loss += CE(logits,targets)

                # forward() scores each prefix independently, so row n of the
                # batched logits equals the logits of self([X[n]]).
                preds = torch.argmax(logits,dim=-1)
                true_counter += int((preds == targets).sum())
                compare_counter += len(Y)

        accuracy = true_counter/float(compare_counter) if compare_counter else 0.0
        return loss, accuracy
    
class Ki2GM(nn.Module):
    """Joint model: a TextGen generator plus a KInspector that must rebuild
    each datapoint's knowledge graph from the generator's pooled ``rep``."""

    def __init__(self):

        super().__init__()
        self.model = TextGen()
        self.inspector = KInspector()

    def forward(self,X):
        """Run only the text generator and return its ``(rep, logits)``."""
        rep, logits = self.model(X)
        return rep, logits

    def train(self, n_epochs=None):
        """Jointly train generator and inspector on ``CLEVRER_data``.

        Args:
            n_epochs: number of epochs. ``None`` (the default) falls back to
                the module-level ``epochs`` setting, as before. A bool is
                treated as an ``nn.Module`` mode switch, because
                ``nn.Module.eval()`` calls ``self.train(False)``. It is
                delegated to ``nn.Module.train`` instead of raising
                ``TypeError``.

        NOTE(review): ``gen_loss`` is a per-token mean while ``kg_loss`` is a
        sum over ``max_nodes**2`` cells, so the graph term likely dominates.
        """
        if isinstance(n_epochs, bool):
            return super().train(n_epochs)
        if n_epochs is None:
            n_epochs = epochs  # module-level setting from __main__

        opt = torch.optim.AdamW(self.parameters())
        CE = nn.CrossEntropyLoss()

        for epoch in range(n_epochs):
            shuffle(CLEVRER_data)
            for datapoint in CLEVRER_data:
                X, Y = DataOps.process_datapoint(datapoint,tk)
                if not X:
                    continue  # fewer than two tokens: nothing to predict
                rep, logits = self.model(X)
                # Class-index targets give the same loss as one-hot probabilities.
                gen_loss = CE(logits,torch.tensor(Y))

                KG = datapoint[0][1]['KG']
                G, nG = GraphOps.process_graph(KG)
                kg_a = self.inspector(G,nG,rep)
                gt = torch.zeros(max_nodes,max_nodes)
                gt[:nG, :nG] = G
                kg_loss = torch.sum(torch.pow(gt-kg_a,2))

                loss = kg_loss + gen_loss
                loss.backward()
                opt.step()
                opt.zero_grad()
                eval_loss,eval_acc = self.model.evaluate()
                print (loss.item())
                print ('eval metrics (loss,acc)', eval_loss,eval_acc)

if __name__ == '__main__':

    # Knowledge-graph inspector settings (max_nodes is derived from data below).
    max_nodes = None
    kge_size = 96  # must equal e_size: Ki2GM adds TextGen's rep to node embeddings
    kgn_layers = 2

    # Text generator settings. c_size is grown by TextTokenizer and n_tokens
    # is taken from its vocabulary.
    c_size = 0
    n_tokens = None
    e_size = 96
    h_size = 96
    n_layers = 2
    epochs = 100

    # NOTE(review): presumably unpickles the file; only load trusted data.
    all_data = fileIO.read_pickle('CLEVRER_data.pkl')

    # Vocabulary and graph size come from ALL data so that evaluation tokens
    # and nodes are always representable by the models.
    tk = TextTokenizer(all_data)
    n_tokens = tk.n_tokens
    nGs = [GraphOps.process_graph(item[0][1]['KG'])[1] for item in all_data]
    max_nodes = max(nGs)

    # Hold out the last 20% (at least one item) for evaluation and train on the
    # rest only, so evaluation does not measure memorised training data. An
    # explicit start index avoids the `data[-0:]` pitfall, which returns the
    # whole list when fewer than 5 items exist.
    n_eval = max(1, int(0.2*len(all_data)))
    split_at = len(all_data) - n_eval
    CLEVRER_data = all_data[:split_at]
    eval_data = all_data[split_at:]

    #model = TextGen()
    #model.train()

    #model = KInspector()
    #model.train()

    model = Ki2GM()
    model.train()